# -*- coding: UTF-8 -*-
import numpy as np
import copy

from agents.agent import Agent


class SFQL_CLUSTER(Agent):
    """
    Tabular successor-feature Q-learning agent that additionally maintains a
    set of cluster centroids ``self.mu`` updated by an online (stochastic
    approximation) nearest-centroid rule.

    Reward weights, successor-feature tables and GPI/GPE evaluation are all
    delegated to the lookup table ``self.sf``.
    """

    def __init__(self, lookup_table, *args, num_clusters=4, use_gpi=True, **kwargs):
        """
        Creates a new tabular successor feature agent.

        Parameters
        ----------
        lookup_table : TabularSF
            a tabular successor feature representation
        num_clusters : int
            number of cluster centroids kept in ``self.mu`` (defaults to 4)
        use_gpi : boolean
            whether or not to use transfer learning (defaults to True)
        *args, **kwargs
            forwarded unchanged to ``Agent.__init__``
        """
        super(SFQL_CLUSTER, self).__init__(*args, **kwargs)
        self.sf = lookup_table
        self.use_gpi = use_gpi
        self.num_clusters = num_clusters
        # maps each visited raw state to the list of actions taken from it
        self.transitions = {}

    def update_centroids(self, transitions, policy_index):
        """
        Apply a TD update to the successor features of policy
        ``policy_index`` for each transition.

        Despite its name, this updates successor features, not cluster
        centroids (the centroid update is ``update_centroid``).

        Parameters
        ----------
        transitions : iterable of tuples
            ``(state, action, phi, next_state, next_action, gamma)`` tuples,
            the same layout ``train_agent`` passes to
            ``self.sf.update_successor``
        policy_index : int
            index of the successor-feature table to update
        """
        # The previous body indexed ``self.psi``, which this agent never
        # defines (the SF table lives in ``self.sf``), so every call raised
        # AttributeError. Delegate to the lookup table's own SF update.
        # NOTE(review): this uses the lookup table's step size rather than a
        # ``self.alpha`` on the agent; confirm that is the intended rate.
        self.sf.update_successor(transitions, policy_index)

    def get_Q_values(self, s, s_enc):
        """
        Evaluate Q-values for the encoded state ``s_enc`` with GPI.

        Side effect: stores in ``self.c`` the policy index returned by
        ``self.sf.GPI`` (or the current task index when GPI is disabled);
        ``train_agent`` relies on this value.

        Parameters
        ----------
        s : object
            raw state (unused here)
        s_enc : object
            encoded state passed to the lookup table

        Returns
        -------
        np.ndarray
            ``q[:, self.c, :]`` -- assumes GPI returns an array indexed as
            (batch, policy, action); TODO confirm against TabularSF
        """
        q, self.c = self.sf.GPI(s_enc, self.task_index, update_counters=self.use_gpi)
        if not self.use_gpi:
            self.c = self.task_index
        return q[:, self.c, :]

    def train_agent(self, s, s_enc, a, r, s1, s1_enc, gamma):
        """
        Perform one learning step from the transition (s, a, r, s1).

        Records the visited (state, action) pair, updates the reward weights
        of the current task, then applies a TD update to the successor
        features of the current task and, if different, of the source policy
        ``self.c`` selected by the last ``get_Q_values`` call.

        Requires ``get_Q_values`` to have been called before, since it sets
        ``self.c``.
        """
        # record visitation, if new
        visited_actions = self.transitions.setdefault(s, [])
        if a not in visited_actions:
            visited_actions.append(a)

        # update w
        t = self.task_index
        phi = self.phi(s, a, s1)
        self.sf.update_reward(phi, r, t)

        # update SF for the current task t; the bootstrap action is greedy
        # w.r.t. GPI over all policies (or plain GPE of policy t)
        if self.use_gpi:
            q1, _ = self.sf.GPI(s1_enc, t)
            q1 = np.max(q1[0, :, :], axis=0)
        else:
            q1 = self.sf.GPE(s1_enc, t, t)[0, :]
        next_action = np.argmax(q1)
        transitions = [(s_enc, a, phi, s1_enc, next_action, gamma)]
        self.sf.update_successor(transitions, t)

        # update SF for source task c, bootstrapping with its own greedy action
        if self.c != t:
            # index the single encoded state the same way as the GPE call above
            q1 = self.sf.GPE(s1_enc, self.c, self.c)[0, :]
            next_action = np.argmax(q1)
            transitions = [(s_enc, a, phi, s1_enc, next_action, gamma)]
            self.sf.update_successor(transitions, self.c)

    # clustering SFs given centroids
    def cluster(self, mu, X):
        """
        Assign each element of ``X`` to its nearest centroid in ``mu``.

        Parameters
        ----------
        mu : sequence of array-like
            cluster centroids
        X : iterable of array-like
            points to assign

        Returns
        -------
        np.ndarray
            for each element of ``X``, the index of the centroid with the
            smallest Euclidean distance (ties go to the lowest index)

        Raises
        ------
        ValueError
            if ``X`` is non-empty but ``mu`` contains no centroids
        """
        cluster_idx = []
        for x in X:
            dists = [np.linalg.norm(m - x) for m in mu]
            if not dists:
                raise ValueError('cannot assign points to clusters: no centroids given')
            # np.argmin returns the first minimum, matching a strict '<' scan
            cluster_idx.append(int(np.argmin(dists)))
        return np.array(cluster_idx)

    def update_centroid(self, X_, sf):
        """
        One stochastic-approximation step of online clustering on ``self.mu``.

        A row of ``X_`` is sampled uniformly at random and assigned to its
        nearest centroid ``h``. ``mu[h]`` is then moved toward it with step
        size ``2 * alpha_mu / sqrt(v_mu[h])``, and ``v_mu[h]`` is incremented.

        Parameters
        ----------
        X_ : np.ndarray
            candidate points, one per entry along axis 0
        sf : object
            unused; kept for interface compatibility
        """
        # sample SF
        i = np.random.choice(X_.shape[0])
        x = X_[i]
        # only the sampled point's assignment is needed, so cluster just it
        h = self.cluster(self.mu, [x])[0]
        # step size decays with the number of updates this centroid received
        self.mu[h] = self.mu[h] + 2 * self.alpha_mu / np.sqrt(self.v_mu[h]) * (x - self.mu[h])
        # update centroid update counter
        self.v_mu[h] = self.v_mu[h] + 1

    def reset(self):
        """Reset the base agent and the successor-feature lookup table."""
        super(SFQL_CLUSTER, self).reset()
        self.sf.reset()

    def add_training_task(self, task):
        """
        Register a new training task with the agent and the lookup table, and
        (re)initialize the cluster centroids and their step-size state.

        NOTE(review): centroids and counters are reinitialized on every call,
        discarding those learned on earlier tasks -- confirm this is intended.
        ``self.noise_init`` is not defined in this class; presumably it is
        provided by ``Agent`` -- TODO confirm.
        """
        super(SFQL_CLUSTER, self).add_training_task(task)
        self.sf.add_training_task(task, -1)
        #  init cluster centroids, one row per cluster
        n_features = task.feature_dim()
        self.mu = self.noise_init((self.num_clusters, n_features))
        # init learning rate for SA
        self.alpha_mu = 0.1
        # initialize counter of centroid updates (starts at 1 so sqrt(v) > 0)
        self.v_mu = np.ones((self.num_clusters,))

    def get_progress_strings(self):
        """
        Return the base progress strings plus a line reporting GPI usage and
        the error between fitted and true reward weights of the current task.
        """
        sample_str, reward_str = super(SFQL_CLUSTER, self).get_progress_strings()
        gpi_percent = self.sf.GPI_usage_percent(self.task_index)
        w_error = np.linalg.norm(self.sf.fit_w[self.task_index] - self.sf.true_w[self.task_index])
        gpi_str = 'GPI% \t {:.4f} \t w_err \t {:.4f}'.format(gpi_percent, w_error)
        return sample_str, reward_str, gpi_str
    
